{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0YC2vlsGs5tg"
   },
   "source": [
    "# Semantic Segmentation with KerasCV\n",
    "\n",
    "**Author:** [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)<br>\n",
    "**Date created:** 2023/08/22<br>\n",
    "**Last modified:** 2023/08/24<br>\n",
    "**Description:** Train and use DeepLabv3+ segmentation model with KerasCV."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zEUpBnaGs5th"
   },
   "source": [
    "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png)\n",
    "\n",
    "## Background\n",
    "Semantic segmentation is a type of computer vision task that involves assigning a\n",
    "class label such as person, bike, or background to each individual pixel of an\n",
    "image, effectively dividing the image into regions that correspond to different\n",
    "fobject classes or categories.\n",
    "\n",
    "![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png)\n",
    "\n",
    "\n",
    "\n",
    "KerasCV offers the DeepLabv3+ model developed by Google for semantic\n",
    "segmentation. This guide demonstrates how to finetune and use DeepLabv3+ model for\n",
    "image semantic segmentaion with KerasCV. Its architecture that combines atrous convolutions,\n",
    "contextual information aggregation, and powerful backbones to achieve accurate and\n",
    "detailed semantic segmentation. The DeepLabv3+ model has been shown to achieve\n",
    "state-of-the-art results on a variety of image segmentation benchmarks.\n",
    "\n",
    "### References\n",
    "[Encoder-Decoder with Atrous Separable Convolution for Semantic Image\n",
    "Segmentation](https://arxiv.org/abs/1802.02611)<br>\n",
    "[Rethinking Atrous Convolution for Semantic Image\n",
    "Segmentation](https://arxiv.org/abs/1706.05587)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vgm-Z4Rus5ti"
   },
   "source": [
    "## Setup and Imports\n",
    "\n",
    "Let's install the dependencies and import the necessary modules.\n",
    "\n",
    "To run this tutorial, you will need to install the following packages:\n",
    "\n",
    "* `keras-cv`\n",
    "* `keras-core`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "89IDcffts5ti"
   },
   "outputs": [],
   "source": [
    "!pip install -q --upgrade keras-cv\n",
    "!pip install -q --upgrade keras # Upgrade to Keras 3."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aT_RCAG3s5tj"
   },
   "source": [
    "After installing `keras-core` and `keras-cv`, set the backend for `keras-core`.\n",
    "This guide can be run with any backend (Tensorflow, JAX, PyTorch).\n",
    "\n",
    "```\n",
    "import os\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xRyHrUEDs5tj"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras import ops\n",
    "\n",
    "import keras_cv\n",
    "import numpy as np\n",
    "\n",
    "from keras_cv.datasets.pascal_voc.segmentation import load as load_voc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "98f7WhdZs5tj"
   },
   "source": [
    "## Perform semantic segmentation with a pretrained DeepLabv3+ model\n",
    "\n",
    "The highest level API in the KerasCV semantic segmentation API is the `keras_cv.models`\n",
    "API. This API includes fully pretrained semantic segmentation models, such as\n",
    "`keras_cv.models.DeepLabV3Plus`.\n",
    "\n",
    "Let's get started by constructing a DeepLabv3+ pretrained on the pascalvoc dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M97l1P2Ms5tj"
   },
   "outputs": [],
   "source": [
    "model = keras_cv.models.DeepLabV3Plus.from_preset(\n",
    "    \"deeplab_v3_plus_resnet50_pascalvoc\",\n",
    "    num_classes=21,\n",
    "    input_shape=[512, 512, 3],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9lUDEOr4s5tk"
   },
   "source": [
    "Let us visualize the results of this pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nUzsOeyqs5tk"
   },
   "outputs": [],
   "source": [
    "filepath = keras.utils.get_file(origin=\"https://i.imgur.com/gCNcJJI.jpg\")\n",
    "image = keras.utils.load_img(filepath)\n",
    "\n",
    "resize = keras_cv.layers.Resizing(height=512, width=512)\n",
    "image = resize(image)\n",
    "image = keras.ops.expand_dims(np.array(image), axis=0)\n",
    "preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)\n",
    "keras_cv.visualization.plot_segmentation_mask_gallery(\n",
    "    image,\n",
    "    value_range=(0, 255),\n",
    "    num_classes=1,\n",
    "    y_true=None,\n",
    "    y_pred=preds,\n",
    "    scale=3,\n",
    "    rows=1,\n",
    "    cols=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vyqoiZcis5tk"
   },
   "source": [
    "## Train a custom semantic segmentation model\n",
    "In this guide, we'll assemble a full training pipeline for a KerasCV DeepLabV3 semantic\n",
    "segmentation model. This includes data loading, augmentation, training, metric\n",
    "evaluation, and inference!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bLz1WdoZs5tk"
   },
   "source": [
    "## Download the data\n",
    "\n",
    "We download\n",
    "[Pascal VOC dataset](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz)\n",
    "with KerasCV datasets and split them into train dataset `train_ds` and `eval_ds`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nfB7NSHHs5tk"
   },
   "outputs": [],
   "source": [
    "train_ds = load_voc(split=\"sbd_train\")\n",
    "eval_ds = load_voc(split=\"sbd_eval\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fFF-YE1fs5tl"
   },
   "source": [
    "## Preprocess the data\n",
    "\n",
    "The `preprocess_tfds_inputs` utility function preprocesses the inputs to a dictionary of\n",
    "`images` and `segmentation_masks`. The images and segmentation masks are resized to\n",
    "512x512. The resulting dataset is then batched into groups of 4 image and segmentation\n",
    "mask pairs.\n",
    "\n",
    "A batch of this preprocessed input training data can be visualized using the\n",
    "`keras_cv.visualization.plot_segmentation_mask_gallery` function. This function takes a\n",
    "batch of images and segmentation masks as input and displays them in a grid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mD0Y8iMLs5tl"
   },
   "outputs": [],
   "source": [
    "def preprocess_tfds_inputs(inputs):\n",
    "    def unpackage_tfds_inputs(tfds_inputs):\n",
    "        return {\n",
    "            \"images\": tfds_inputs[\"image\"],\n",
    "            \"segmentation_masks\": tfds_inputs[\"class_segmentation\"],\n",
    "        }\n",
    "\n",
    "    outputs = inputs.map(unpackage_tfds_inputs)\n",
    "    outputs = outputs.map(keras_cv.layers.Resizing(height=512, width=512))\n",
    "    outputs = outputs.batch(4, drop_remainder=True)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "train_ds = preprocess_tfds_inputs(train_ds)\n",
    "batch = train_ds.take(1).get_single_element()\n",
    "keras_cv.visualization.plot_segmentation_mask_gallery(\n",
    "    batch[\"images\"],\n",
    "    value_range=(0, 255),\n",
    "    num_classes=21,  # The number of classes for the oxford iiit pet dataset. The VOC dataset also includes 1 class for the background.\n",
    "    y_true=batch[\"segmentation_masks\"],\n",
    "    scale=3,\n",
    "    rows=2,\n",
    "    cols=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7NIGx0zHs5tl"
   },
   "source": [
    "The preprocessing is applied to the evaluation dataset `eval_ds`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "t0264OIJs5tl"
   },
   "outputs": [],
   "source": [
    "eval_ds = preprocess_tfds_inputs(eval_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KfPbd-TTs5tl"
   },
   "source": [
    "## Data Augmentation\n",
    "\n",
    "KerasCV provides a variety of image augmentation options. In this example, we will use\n",
    "the `RandomFlip` augmentation to augment the training dataset. The `RandomFlip`\n",
    "augmentation randomly flips the images in the training dataset horizontally or\n",
    "vertically. This can help to improve the model's robustness to changes in the orientation\n",
    "of the objects in the images."
   ]
  },
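  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal NumPy sketch of a horizontal flip applied to an\n",
    "image and its mask together. It is an illustration only, not how `RandomFlip` is implemented;\n",
    "the helper name `random_horizontal_flip` and its `rate` argument are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def random_horizontal_flip(image, mask, rng, rate=0.5):\n",
    "    # Hypothetical helper: flip image and mask along the width axis with\n",
    "    # probability `rate`, so every pixel keeps its label.\n",
    "    if rng.random() < rate:\n",
    "        return image[:, ::-1], mask[:, ::-1]\n",
    "    return image, mask\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "image = np.arange(12).reshape(2, 2, 3)  # (height, width, channels)\n",
    "mask = np.array([[[0], [1]], [[2], [3]]])  # (height, width, 1)\n",
    "\n",
    "# rate=1.0 always flips and rate=0.0 never flips.\n",
    "flipped_image, flipped_mask = random_horizontal_flip(image, mask, rng, rate=1.0)\n",
    "same_image, same_mask = random_horizontal_flip(image, mask, rng, rate=0.0)\n",
    "```\n",
    "\n",
    "Flipping only the image would leave every mask pixel pointing at the wrong location, which is\n",
    "why the two arrays are always transformed as a pair."
   ]
  },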
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W_0Ei44ls5tl"
   },
   "outputs": [],
   "source": [
    "train_ds = train_ds.map(keras_cv.layers.RandomFlip())\n",
    "batch = train_ds.take(1).get_single_element()\n",
    "\n",
    "keras_cv.visualization.plot_segmentation_mask_gallery(\n",
    "    batch[\"images\"],\n",
    "    value_range=(0, 255),\n",
    "    num_classes=21,\n",
    "    y_true=batch[\"segmentation_masks\"],\n",
    "    scale=3,\n",
    "    rows=2,\n",
    "    cols=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M99ecGY4s5tm"
   },
   "source": [
    "## Model Configuration\n",
    "\n",
    "Please feel free to modify the configurations for model training and note how the\n",
    "training results changes. This is an great exercise to get a better understanding of the\n",
    "training pipeline.\n",
    "\n",
    "The learning rate schedule is used by the optimizer to calculate the learning rate for\n",
    "each epoch. The optimizer then uses the learning rate to update the weights of the model.\n",
    "In this case, the learning rate schedule uses a cosine decay function. A cosine decay\n",
    "function starts high and then decreases over time, eventually reaching zero. The\n",
    "cardinality of the VOC dataset is 2124 with a batch size of 4. The dataset cardinality\n",
    "is important for learning rate decay because it determines how many steps the model\n",
    "will train for. The initial learning rate is proportional to 0.007 and the decay\n",
    "steps are 2124. This means that the learning rate will start at `INITIAL_LR` and then\n",
    "decrease to zero over 2124 steps.\n",
    "![png](/img/guides/semantic_segmentation_deeplab_v3_plus/learning_rate_schedule.png)"
   ]
  },
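  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the schedule described above, the decay can be written in plain Python.\n",
    "This assumes the standard cosine decay formula with no warmup and a final multiplier `alpha`\n",
    "of 0; the function name `cosine_decay` is a hypothetical stand-in, not the Keras implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_decay(step, initial_lr, decay_steps, alpha=0.0):\n",
    "    # Hold the final value once `decay_steps` has been reached.\n",
    "    step = min(step, decay_steps)\n",
    "    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))\n",
    "    return initial_lr * ((1.0 - alpha) * cosine + alpha)\n",
    "\n",
    "\n",
    "initial_lr = 0.007 * 4 / 16  # 0.007 scaled by BATCH_SIZE / 16\n",
    "decay_steps = 2124  # one epoch of 2124 batches\n",
    "\n",
    "lr_start = cosine_decay(0, initial_lr, decay_steps)\n",
    "lr_mid = cosine_decay(decay_steps // 2, initial_lr, decay_steps)\n",
    "lr_end = cosine_decay(decay_steps, initial_lr, decay_steps)\n",
    "```\n",
    "\n",
    "Under these assumptions the rate starts at 0.00175, passes through half that value at the\n",
    "midpoint, and reaches zero at the final step."
   ]
  },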
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4zqr0oF5s5tm"
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 4\n",
    "INITIAL_LR = 0.007 * BATCH_SIZE / 16\n",
    "EPOCHS = 1\n",
    "NUM_CLASSES = 21\n",
    "learning_rate = keras.optimizers.schedules.CosineDecay(\n",
    "    INITIAL_LR,\n",
    "    decay_steps=EPOCHS * 2124,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ES4SUSims5tm"
   },
   "source": [
    "We instantiate a DeepLabV3+ model with a ResNet50 backbone pretrained on ImageNet classification:\n",
    "`resnet50_v2_imagenet` pre-trained weights will be used as the backbone feature\n",
    "extractor for the DeepLabV3Plus model. The `num_classes` parameter specifies the number of\n",
    "classes that the model will be trained to segment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LoNY90Cgs5tm"
   },
   "outputs": [],
   "source": [
    "model = keras_cv.models.DeepLabV3Plus.from_preset(\n",
    "    \"resnet50_v2_imagenet\", num_classes=NUM_CLASSES\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wlwA_LTUs5tm"
   },
   "source": [
    "## Compile the model\n",
    "\n",
    "The model.compile() function sets up the training process for the model. It defines the\n",
    "- optimization algorithm - Stochastic Gradient Descent (SGD)\n",
    "- the loss function - categorical cross-entropy\n",
    "- the evaluation metrics - Mean IoU and categorical accuracy\n",
    "\n",
    "Semantic segmentation evaluation metrics:\n",
    "\n",
    "Mean Intersection over Union (MeanIoU):\n",
    "MeanIoU measures how well a semantic segmentation model accurately identifies\n",
    "and delineates different objects or regions in an image. It calculates the\n",
    "overlap between predicted and actual object boundaries, providing a score\n",
    "between 0 and 1, where 1 represents a perfect match.\n",
    "\n",
    "Categorical Accuracy:\n",
    "Categorical Accuracy measures the proportion of correctly classified pixels in\n",
    "an image. It gives a simple percentage indicating how accurately the model\n",
    "predicts the categories of pixels in the entire image.\n",
    "\n",
    "In essence, MeanIoU emphasizes the accuracy of identifying specific object\n",
    "boundaries, while Categorical Accuracy gives a broad overview of overall\n",
    "pixel-level correctness."
   ]
  },
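  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two metrics is easiest to see on a tiny example. The sketch below\n",
    "is a simplified NumPy illustration, not the `keras.metrics` implementation; the helper names\n",
    "`pixel_accuracy` and `mean_iou` are hypothetical, and classes absent from both arrays are\n",
    "skipped by assumption.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def pixel_accuracy(y_true, y_pred):\n",
    "    # Fraction of pixels whose predicted class matches the label.\n",
    "    return float(np.mean(y_true == y_pred))\n",
    "\n",
    "\n",
    "def mean_iou(y_true, y_pred, num_classes):\n",
    "    ious = []\n",
    "    for class_id in range(num_classes):\n",
    "        true_region = y_true == class_id\n",
    "        pred_region = y_pred == class_id\n",
    "        union = np.logical_or(true_region, pred_region).sum()\n",
    "        if union == 0:\n",
    "            continue  # Class absent from both arrays: nothing to score.\n",
    "        intersection = np.logical_and(true_region, pred_region).sum()\n",
    "        ious.append(intersection / union)\n",
    "    return float(np.mean(ious))\n",
    "\n",
    "\n",
    "# A 4x4 label map: class 1 fills the bottom-right 2x2 block, class 0 the rest.\n",
    "y_true = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]])\n",
    "y_pred = np.zeros_like(y_true)  # A model that predicts background everywhere.\n",
    "\n",
    "accuracy = pixel_accuracy(y_true, y_pred)\n",
    "miou = mean_iou(y_true, y_pred, num_classes=2)\n",
    "```\n",
    "\n",
    "Here the background-only prediction scores 0.75 accuracy because 12 of 16 pixels are correct,\n",
    "but only 0.375 mean IoU, since the missed class contributes an IoU of 0."
   ]
  },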
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uM-Im0Mjs5tn"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=keras.optimizers.SGD(\n",
    "        learning_rate=learning_rate,\n",
    "        weight_decay=0.0001,\n",
    "        momentum=0.9,\n",
    "        clipnorm=10.0,\n",
    "    ),\n",
    "    loss=keras.losses.CategoricalCrossentropy(from_logits=False),\n",
    "    metrics=[\n",
    "        keras.metrics.MeanIoU(\n",
    "            num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False\n",
    "        ),\n",
    "        keras.metrics.CategoricalAccuracy(),\n",
    "    ],\n",
    ")\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Buh6A_1fs5tn"
   },
   "source": [
    "The utility function `dict_to_tuple` effectively transforms the dictionaries of training\n",
    "and validation datasets into tuples of images and one-hot encoded segmentation masks,\n",
    "which is used during training and evaluation of the DeepLabv3+ model."
   ]
  },
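  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-hot conversion described above can be sketched in NumPy on a tiny mask. This is an\n",
    "illustration under the assumption that every class id lies in `[0, num_classes)`; the helper\n",
    "name `one_hot_mask` is hypothetical and is not part of the pipeline.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def one_hot_mask(mask, num_classes):\n",
    "    # Drop the trailing channel axis, then use the class ids to pick rows of an\n",
    "    # identity matrix, which yields one indicator channel per class.\n",
    "    class_ids = np.squeeze(mask, axis=-1).astype(np.int32)\n",
    "    return np.eye(num_classes, dtype=np.float32)[class_ids]\n",
    "\n",
    "\n",
    "mask = np.array([[[0], [1]], [[2], [1]]])  # (height, width, 1)\n",
    "encoded = one_hot_mask(mask, num_classes=3)  # (height, width, 3)\n",
    "\n",
    "# Taking the argmax over the class axis recovers the original class ids.\n",
    "decoded = np.argmax(encoded, axis=-1)\n",
    "```\n",
    "\n",
    "The same argmax step is what turns the model's per-class outputs back into a single class map\n",
    "for visualization."
   ]
  },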
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kOLcpKLbs5tn"
   },
   "outputs": [],
   "source": [
    "def dict_to_tuple(x):\n",
    "    import tensorflow as tf\n",
    "\n",
    "    return x[\"images\"], tf.one_hot(\n",
    "        tf.cast(tf.squeeze(x[\"segmentation_masks\"], axis=-1), \"int32\"), 21\n",
    "    )\n",
    "\n",
    "\n",
    "train_ds = train_ds.map(dict_to_tuple)\n",
    "eval_ds = eval_ds.map(dict_to_tuple)\n",
    "\n",
    "model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r8ZSZmtPs5tn"
   },
   "source": [
    "## Predictions with trained model\n",
    "Now that the model training of DeepLabv3+ has completed, let's test it by making\n",
    "predications\n",
    "on a few sample images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RG07dyEUs5tn"
   },
   "outputs": [],
   "source": [
    "test_ds = load_voc(split=\"sbd_eval\")\n",
    "test_ds = preprocess_tfds_inputs(test_ds)\n",
    "\n",
    "images, masks = next(iter(train_ds.take(1)))\n",
    "images = ops.convert_to_tensor(images)\n",
    "masks = ops.convert_to_tensor(masks)\n",
    "preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)\n",
    "masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)\n",
    "\n",
    "keras_cv.visualization.plot_segmentation_mask_gallery(\n",
    "    images,\n",
    "    value_range=(0, 255),\n",
    "    num_classes=21,\n",
    "    y_true=masks,\n",
    "    y_pred=preds,\n",
    "    scale=3,\n",
    "    rows=1,\n",
    "    cols=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "loWDjb1_s5tn"
   },
   "source": [
    "Here are some additional tips for using the KerasCV DeepLabv3+ model:\n",
    "\n",
    "- The model can be trained on a variety of datasets, including the COCO dataset, the\n",
    "PASCAL VOC dataset, and the Cityscapes dataset.\n",
    "- The model can be fine-tuned on a custom dataset to improve its performance on a\n",
    "specific task.\n",
    "- The model can be used to perform real-time inference on images.\n",
    "- Also, try out KerasCV's SegFormer model `keras_cv.models.segmentation.SegFormer`. The\n",
    "SegFormer model is a newer model that has been shown to achieve state-of-the-art results\n",
    "on a variety of image segmentation benchmarks. It is based on the Swin Transformer\n",
    "architecture, and it is more efficient and accurate than previous image segmentation\n",
    "models."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "semantic_segmentation_deeplab_v3_plus",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}